Add tree attention backend for v1 (part 1) #20401

Open · TheEpicDolphin wants to merge 4 commits into main from tree_attention_v1

Conversation

@TheEpicDolphin commented on Jul 2, 2025

Purpose

Add support for a tree attention backend in v1. Tree attention is used in EAGLE speculative decoding by the target model to validate a set of draft tokens. Each draft token should attend only to its ancestor tokens, so an attention bias must be used to mask out attention between tokens that do not lie on the same root-to-leaf path.

Currently, TreeAttentionImpl uses xformers' tree_attention operation. This operation requires both a prefix and a suffix attention bias: the former governs attention from the draft tokens to the prompt tokens, while the latter governs attention among the draft tokens and their ancestors. The outputs of the two attention computations are then merged.
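
To make the suffix (tree) bias concrete, here is a minimal sketch, assuming a parent-index representation of the draft tree; the helper name build_tree_attn_bias is illustrative and not code from this PR. Each draft token is allowed to attend to itself and its ancestors, and every other position is masked with -inf.

import torch

def build_tree_attn_bias(parents: list[int]) -> torch.Tensor:
    """parents[i] is the index of draft token i's parent, or -1 for a root node."""
    num_tokens = len(parents)
    bias = torch.full((num_tokens, num_tokens), float("-inf"))
    for i in range(num_tokens):
        node = i
        while node != -1:
            bias[i, node] = 0.0  # attend to self and every ancestor
            node = parents[node]
    return bias

# Root with two children: each child sees itself and the root, but not its sibling.
print(build_tree_attn_bias([-1, 0, 0]))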

Test Plan

Added the test test_tree_attn_correctness, which verifies that the tree attention output for draft chains matches the flash attention output (within tolerance) for the same number of query tokens across several configurations. This validates the correctness of the backend.
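
As a hedged illustration of the property being tested (using plain PyTorch SDPA rather than the vLLM backends), for a chain draft the tree bias degenerates to a causal mask, so biased attention and ordinary causal attention should agree:

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 64)  # (batch, heads, chain_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# For a chain, "attend only to ancestors" is exactly the causal mask.
chain_bias = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
out_biased = F.scaled_dot_product_attention(q, k, v, attn_mask=chain_bias)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_biased, out_causal, atol=1e-6)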

Benchmark

In addition, I used the following commands to run the LLM service and benchmark TreeAttentionBackend against FlashAttentionBackend:
Server

export VLLM_TORCH_PROFILER_DIR=~/traces/vllm
export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
export DRAFT_MODEL=yuhuili/EAGLE-LLaMA3.1-Instruct-8B
export VLLM_USE_V1=1
export VLLM_ATTENTION_BACKEND=<backend>
export SPEC_DEC_CONFIG='{"method": "eagle", "model": "'$DRAFT_MODEL'", "num_speculative_tokens": 3, "draft_tensor_parallel_size": 1, "max_model_len": 2048, "speculative_token_tree": "[(0,), (0, 0), (0, 0, 0)]"}'
python -m vllm.entrypoints.openai.api_server --model $LLAMA_MODEL --disable-log-requests --tensor-parallel-size=1 --max-num-seqs=64 --max-model-len=32768 --block-size=128 --no-enable-prefix-caching --speculative-config="$SPEC_DEC_CONFIG" 2>&1 | tee ~/server_logs/vllm_server.log

Client

export LLAMA_MODEL=meta-llama/Llama-3.1-8B-Instruct
python benchmarks/benchmark_serving.py --model $LLAMA_MODEL --tokenizer $LLAMA_MODEL --host 0.0.0.0 --dataset-name random --ignore-eos --request-rate inf --random-input-len 1000 --random-output-len 16 --max-concurrency 64 --num-prompts 64

Results

| Serving Benchmark Result | Flash Attention | Tree Attention |
| --- | --- | --- |
| Successful requests | 64 | 64 |
| Benchmark duration (s) | 2.17 | 2.96 |
| Total input tokens | 63936 | 63936 |
| Total generated tokens | 1024 | 1024 |
| Request throughput (req/s) | 29.5 | 21.66 |
| Output token throughput (tok/s) | 471.97 | 346.49 |
| Total token throughput (tok/s) | 29940.28 | 21980.35 |
| Time to First Token | | |
| Mean TTFT (ms) | 1069.63 | 1403.4 |
| Median TTFT (ms) | 993.21 | 1294.13 |
| P99 TTFT (ms) | 1881.43 | 2466.42 |
| Time per Output Token (excl. 1st token) | | |
| Mean TPOT (ms) | 68.25 | 93.54 |
| Median TPOT (ms) | 71.12 | 97.59 |
| P99 TPOT (ms) | 121.67 | 161.78 |
| Inter-token Latency | | |
| Mean ITL (ms) | 72.01 | 98.47 |
| Median ITL (ms) | 20.47 | 34.55 |
| P99 ITL (ms) | 251.47 | 310.7 |
| SpecDecoding Metrics | | |
| Draft acceptance rate | 2.40% | 2.50% |
| Mean acceptance length | 1.07 | 1.07 |
| Accepted | 68 | 69 |
| Drafted | 2796 | 2790 |
| Per-position acceptance rate | 0.069, 0.004, 0.000 | 0.069, 0.004, 0.001 |

This benchmarking helped me verify that this PR does NOT regress performance on v1 spec decoding. The tree attention backend itself still needs performance improvements; I will investigate further how to close the gap.

Manual Testing

I used the code below to send a chat completion request to the vLLM service running with the TREE_ATTN backend:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain the theory of relativity in simple terms."},
    ],
    temperature=0.2,
)
print(response)

Flash Attention Output

ChatCompletion(id='chatcmpl-6fcc6d98bce64d45b18dc795faa788f5', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The theory of relativity, developed by Albert Einstein, is a fundamental concept in modern physics. I'll break it down in simple terms:\n\n**What is the theory of relativity?**\n\nThe theory of relativity is a way of understanding how the universe works, particularly when it comes to space and time. It's divided into two main parts: special relativity and general relativity.\n\n**Special Relativity (1905)**\n\nSpecial relativity says that how we measure time and space can be different depending on how fast we're moving and where we are. Here are the key points:\n\n1. **Time dilation**: Time can appear to slow down or speed up depending on your speed. The faster you move, the slower time passes.\n2. **Length contraction**: Objects can appear shorter when you're moving really fast.\n3. **The speed of light is always the same**: No matter how fast you're moving, the speed of light remains constant.\n4. **Relativity of simultaneity**: Two events that happen at the same time for one observer might not happen at the same time for another observer in a different state of motion.\n\n**General Relativity (1915)**\n\nGeneral relativity builds on special relativity and adds gravity to the mix. It says that:\n\n1. **Gravity is not a force**: Gravity is actually the curvature of spacetime caused by massive objects.\n2. **Spacetime is flexible**: The presence of massive objects warps spacetime, creating gravitational fields.\n3. **Equivalence principle**: The effects of gravity are equivalent to the effects of acceleration.\n\n**Key Takeaways**\n\nThe theory of relativity revolutionized our understanding of space, time, and gravity. Some of the key implications include:\n\n* Time and space are not absolute, but relative to the observer.\n* The laws of physics are the same everywhere in the universe.\n* Gravity is not a force, but a result of the curvature of spacetime.\n\n**In Simple Terms**\n\nImagine you're on a train, and you throw a ball straight up in the air. To you, on the train, the ball goes straight up and comes straight back down. But to someone watching from the platform, the ball looks like it's moving in a curved path because the train is moving really fast.\n\nThat's kind of like what's happening with time and space in the theory of relativity. The faster you move, the more time and space can appear to change. And gravity is like a big, cosmic curve that warps spacetime, affecting how objects move and interact.\n\nI hope this helps you understand the basics of the theory of relativity!", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], created=1752615338, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=536, prompt_tokens=52, total_tokens=588, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None, kv_transfer_params=None)

Tree Attention Output

ChatCompletion(id='chatcmpl-1ff4447ed33e4a91b89fa3f1e25d1b14', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content="The theory of relativity, developed by Albert Einstein, is a fundamental concept in modern physics. I'll break it down in simple terms:\n\n**What is the theory of relativity?**\n\nThe theory of relativity is a way of understanding how the universe works, particularly when it comes to space and time. It's based on two main ideas: special relativity and general relativity.\n\n**Special Relativity (1905)**\n\nSpecial relativity says that how we measure time and space can be different depending on how fast we're moving and where we are. Here are some key points:\n\n1. **Time dilation**: Time can seem to pass slower for someone moving really fast compared to someone who is standing still.\n2. **Length contraction**: Objects can appear shorter to someone moving really fast compared to someone who is standing still.\n3. **The speed of light is always the same**: No matter how fast you're moving, the speed of light remains the same.\n\n**General Relativity (1915)**\n\nGeneral relativity builds on special relativity and adds a new idea: gravity is not a force, but rather a curvature of space and time caused by massive objects. Here are some key points:\n\n1. **Gravity warps space and time**: The more massive an object is, the more it warps the fabric of space and time around it.\n2. **Gravity is not a force**: Objects don't attract each other with a force called gravity; instead, they follow the curvature of space and time.\n\n**Key Takeaways**\n\n1. **Time and space are relative**: They can be affected by motion and gravity.\n2. **The speed of light is always the same**: It's a universal constant that doesn't change.\n3. **Gravity is a curvature of space and time**: It's not a force, but rather a result of massive objects warping the fabric of the universe.\n\nThe theory of relativity has been extensively tested and confirmed through numerous experiments and observations. It's a fundamental concept in modern physics and has had a profound impact on our understanding of the universe.", refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[], reasoning_content=None), stop_reason=None)], created=1752621554, model='meta-llama/Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=423, prompt_tokens=52, total_tokens=475, completion_tokens_details=None, prompt_tokens_details=None), prompt_logprobs=None, kv_transfer_params=None)

Tree Drafts

I tested generating a tree with the following structure:

ROOT
├── 0
│   ├── 0
│   │   └── 0
│   └── 1
│       └── 0
├── 1
│   ├── 0
│   │   └── 0
│   └── 1
│       └── 0
└── 2
    ├── 0
    │   └── 0
    └── 1
        └── 0

Represented by the following list of tuples:

[(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (0, 0, 0), (0, 1, 0), (1, 0, 0), (1, 1, 0), (2, 0, 0), (2, 1, 0), (0, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 0, 0), (2, 0, 0, 0), (2, 1, 0, 0), (0, 0, 0, 0, 0), (0, 1, 0, 0, 0), (1, 0, 0, 0, 0), (1, 1, 0, 0, 0), (2, 0, 0, 0, 0), (2, 1, 0, 0, 0)]
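
As an aside, here is a hypothetical helper (illustrative only, not code from this PR) showing how the tuple-path representation above maps to per-node parent indices, which is the ancestry information the attention bias ultimately needs:

def paths_to_parents(paths):
    """Return a parent index for each node; -1 means the node hangs directly off the root."""
    ordered = sorted(paths, key=lambda p: (len(p), p))
    index = {p: i for i, p in enumerate(ordered)}
    return [index[p[:-1]] if len(p) > 1 else -1 for p in ordered]

paths = [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0)]
print(paths_to_parents(paths))  # [-1, -1, -1, 0, 0, 1]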

For the input prompt, "Explain the theory of relativity in simple terms.", the backend proposed the following speculative tokens:

"The"
├── " theory"
│   ├── " of"
│   │   └── " rel" ── "ativity" ── ","
│   └── " is"
│       └── " a" ── " fundamental" ── " theory"
├── " Theory"
│   ├── " of"
│   │   └── " Rel" ── "ativity" ── ","
│   └── " Of"
│       └── " Rel" ── "ativity" ── ","
└── " Einstein"
    ├── " Theory"
    │   └── "," ── " " ── " Albert"
    └── "'s"
        └── " Theory" ── " of" ── " Rel"

Each path in the tree sounds coherent.

NOTE: There is currently no way to sample tokens from a tree draft, so running with tree drafts end to end would currently produce gibberish outputs.

TODOs

The following actions still need to be taken to fully enable this backend:

  • Fix paged KV when a tree branch is selected
  • Add general support for setting the draft model's attention backend; it is currently forced to FlashAttentionBackend.

As of this diff, only chain drafts are supported by TreeAttentionBackend. This is because EagleProposer still only generates draft chains.


@gemini-code-assist bot left a comment

Summary of Changes

Hello @TheEpicDolphin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the initial phase of a Tree Attention backend into v1 of the attention system, specifically to support EAGLE speculative decoding. The changes enable the efficient validation of draft tokens by implementing a tree-based attention mechanism that correctly applies necessary attention biases. This work involves significant additions to the attention backend infrastructure, updates to model architecture to utilize the new backend, and includes a correctness test to ensure functionality.

Highlights

  • New Tree Attention Backend: Introduced TreeAttentionBackend and TreeAttentionImpl to add support for tree attention, which is a key component for EAGLE speculative decoding in v1 of the attention system.
  • Attention Bias Implementation: The TreeAttentionImpl leverages xformers.ops.tree_attention and correctly applies both prefix and speculative (suffix) attention biases, essential for managing attention between draft tokens and their ancestors or prompt tokens.
  • Dynamic Backend Selection and Draft Model Support: The attention backend selection logic has been updated to include TREE_ATTN and now incorporates an is_draft flag, allowing the system to differentiate and select appropriate attention backends for draft models within the speculative decoding framework.
  • Optimized Batch Processing: A new TreeAttentionMetadataBuilder was added to reorder batches, prioritizing decode requests, and to efficiently construct attention metadata for both prefill (handled by FlashAttention) and speculative decode phases.
  • Correctness Validation: A new test, test_tree_attn_correctness, was implemented to verify the numerical correctness of the TreeAttentionBackend by comparing its output against FlashAttentionBackend across various configurations.

mergify bot added the llama (Related to Llama models), speculative-decoding, and v1 labels on Jul 2, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new TreeAttentionBackend for speculative decoding, which is a significant feature addition. The implementation is well-structured, reusing FlashAttentionImpl for prefill requests and using xformers for the tree attention part. The new test file provides good coverage for correctness verification.

I've identified a critical issue with duplicated fields in a dataclass and a few medium-severity issues related to code correctness, performance, and maintainability. Addressing these will improve the quality and robustness of the new backend. Overall, this is a great first step towards enabling tree attention.

Comment on lines 85 to 86
block_table: torch.Tensor
slot_mapping: torch.Tensor

critical

The fields block_table and slot_mapping are defined twice in the TreeAttentionMetadata dataclass. This is likely a copy-paste error and should be corrected.

backends=[FlashAttentionBackend, TreeAttentionBackend],
)
assert torch.allclose(
flash_attn_output, tree_attn_output, atol=1e-2

medium

The absolute tolerance atol=1e-2 is a bit high for bfloat16 tensors, which have a machine epsilon of about 7.81e-3. While this might be necessary due to error accumulation in the attention computation, it would be good to either tighten this tolerance if possible, or add a comment explaining why this level of tolerance is required. A tighter tolerance would give more confidence in the correctness of the implementation.
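
For example, a dtype-aware comparison (sketched below with made-up tensors, not the test's actual outputs) scales the allowed error with magnitude instead of relying on a single large absolute bound:

import torch

torch.manual_seed(0)
reference = torch.randn(64, 128, dtype=torch.bfloat16)
# Simulate a result that differs only by small numeric noise plus bfloat16 rounding.
candidate = (reference.float() + 1e-4 * torch.randn(64, 128)).to(torch.bfloat16)

torch.testing.assert_close(
    candidate.float(),
    reference.float(),
    rtol=2 * 7.8125e-3,  # roughly two bfloat16 ulps of relative error
    atol=1e-3,           # small absolute floor for near-zero entries
)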

Comment on lines 172 to 168
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens

medium

Storing intermediate state like _num_decodes, _num_prefills, etc. on self between calls to reorder_batch and build can make the code harder to reason about and potentially fragile. The TODO comment acknowledges this.

A cleaner approach might be for reorder_batch to return this information, and for the caller (in GPUModelRunner) to pass it to build. This would make the data flow more explicit and improve maintainability.

For example:

# In TreeAttentionMetadataBuilder
def reorder_batch(...) -> tuple[bool, dict[str, Any]]:
    ...
    reorder_info = {
        "num_decodes": num_decodes,
        "num_prefills": num_prefills,
        ...
    }
    return modified_batch, reorder_info

# In GPUModelRunner
modified_batch, reorder_info = self.attn_metadata_builder.reorder_batch(...)
...
self.attn_metadata = self.attn_metadata_builder.build(..., reorder_info=reorder_info)

# In TreeAttentionMetadataBuilder
def build(..., reorder_info: dict[str, Any]) -> TreeAttentionMetadata:
    num_decodes = reorder_info["num_decodes"]
    ...

Since this would require changes outside of this file, this can be addressed in a follow-up PR.

ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1

medium

The use of sorted_tree_choices.index() inside a loop can lead to quadratic complexity with respect to the number of nodes in the tree. While this is likely not an issue for the small trees currently used in speculative decoding, it could become a performance bottleneck if larger or more complex trees are supported in the future.

Consider pre-computing a mapping from path to index to achieve O(1) lookups. For example:

path_to_idx = {path: i for i, path in enumerate(sorted_tree_choices)}
# ... inside the loop ...
ancestor_idx.append(path_to_idx[cur_tree_choice[: c + 1]] + 1)

This would improve the maintainability and future-proof the code against performance issues with larger trees.

@TheEpicDolphin force-pushed the tree_attention_v1 branch 4 times, most recently from 5a37c78 to bfa883a on July 2, 2025 at 21:54
@TheEpicDolphin marked this pull request as ready for review on July 2, 2025 at 22:09

mergify bot commented Jul 8, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Jul 8, 2025
@sgrigory left a comment

Thanks for integrating tree attention! Left a few comments. Regarding the performance, maybe look at the profiles to see what takes the most time - it could be the tree attention itself, but it could also be metadata processing (which we can then take out of decoding loop, at least partially)

device=device,
dtype=torch.int32,
).view(-1, num_allocated_blocks_per_batch)
block_table[:, :num_allocated_blocks_per_batch] = block_ids


This simulates a situation when pages are actually ordered contiguously in physical memory. Would the test also work in a more complex scenario? For example, you can swap two pages

https://github.com/facebookresearch/xformers/blob/80250b32516b019b72bb44be04ca9a8741b42faa/tests/test_mem_eff_attention.py#L2696-L2699

or even shuffle them all

https://github.com/Dao-AILab/flash-attention/blob/adf27d1db38223288981c4dc3509efafbddd3422/tests/test_flash_attn.py#L2151-L2155
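
For instance, a minimal sketch (illustrative names, not the test's actual variables) of randomizing the physical page layout before building the block table:

import torch

num_seqs, blocks_per_seq = 4, 8
block_ids = torch.arange(num_seqs * blocks_per_seq, dtype=torch.int32)

# Permute the physical block ids so logical pages are no longer contiguous;
# the KV cache must be populated through this same mapping.
block_table = block_ids[torch.randperm(block_ids.numel())].view(num_seqs, blocks_per_seq)
print(block_table)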

@@ -1442,6 +1442,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
"TREE_ATTN",


Nit: is the comment above "No XFormers so far" still true if you are importing tree attention from xFormers?

@@ -134,7 +134,7 @@ def _get_sliding_window_configs(
sliding_window_configs: set[Optional[tuple[int, int]]] = set()
layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer in layers.values():
assert isinstance(layer.impl, FlashAttentionImpl)
assert hasattr(layer.impl, "sliding_window")


maybe assert isinstance(layer.impl, (FlashAttentionImpl, TreeAttentionImpl))?

return depth_counts


def _prepare_tree_attn_bias(


spec_v=spec_v,
cache_k=cache_k,
cache_v=cache_v,
prefix_op=triton_splitk.FwOp,


@sgrigory

cc @bottler

@TheEpicDolphin force-pushed the tree_attention_v1 branch 12 times, most recently from 0e44c6e to 0e691d5 on July 15, 2025 at 23:25